{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "编码: ['<bos>', 'he', 'll', 'o', '<eos>']\n",
      "解码: he ll o\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from tokenizers import Tokenizer, models, trainers, pre_tokenizers\n",
    "\n",
    "# ======================= 1️⃣  数据加载 =======================\n",
    "def load_text_data(file_paths, separator=\"\\n\"):\n",
    "    \"\"\"Load several UTF-8 text files into one lower-cased training corpus.\n",
    "\n",
    "    Args:\n",
    "        file_paths: Iterable of paths to plain-text files, read in order.\n",
    "        separator: String inserted between consecutive files so the last\n",
    "            word of one file is not fused with the first word of the next.\n",
    "\n",
    "    Returns:\n",
    "        The lower-cased contents of all files joined by ``separator``\n",
    "        (an empty string when ``file_paths`` is empty).\n",
    "\n",
    "    Raises:\n",
    "        OSError: If a file cannot be opened.\n",
    "        UnicodeDecodeError: If a file is not valid UTF-8.\n",
    "    \"\"\"\n",
    "    chunks = []\n",
    "    for path in file_paths:\n",
    "        with open(path, \"r\", encoding=\"utf-8\") as handle:\n",
    "            # Lower-case everything so the corpus is case-insensitive.\n",
    "            chunks.append(handle.read().lower())\n",
    "    # One join instead of repeated += avoids quadratic string copying.\n",
    "    return separator.join(chunks)\n",
    "\n",
    "# Corpus files; replace these with your own data files.\n",
    "file_paths = [\"data1.txt\", \"data2.txt\"]\n",
    "text = load_text_data(file_paths)\n",
    "\n",
    "# ======================= 2️⃣  Train the BPE tokenizer =======================\n",
    "tokenizer = Tokenizer(models.BPE())\n",
    "trainer = trainers.BpeTrainer(special_tokens=[\"<pad>\", \"<bos>\", \"<eos>\"])\n",
    "tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()\n",
    "# Learn the vocabulary from the lower-cased corpus that is encoded later on,\n",
    "# not from the raw files: training on the original casing would spend merges\n",
    "# on upper-case tokens that never occur in `text`.\n",
    "tokenizer.train_from_iterator(text.splitlines(), trainer=trainer)\n",
    "\n",
    "# Quick round-trip check of the tokenizer.\n",
    "encoded = tokenizer.encode(\"hello world\")\n",
    "print(\"编码:\", encoded.tokens)\n",
    "print(\"解码:\", tokenizer.decode(encoded.ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ======================= 3️⃣  构建 Transformer 解码器 =======================\n",
    "class TransformerDecoder(nn.Module):\n",
    "    \"\"\"Small Transformer-decoder language model with learned positional embeddings.\n",
    "\n",
    "    Args:\n",
    "        vocab_size: Number of token ids; also the width of the output logits.\n",
    "        d_model: Embedding / hidden width.\n",
    "        num_heads: Number of attention heads per layer.\n",
    "        num_layers: Number of stacked decoder layers.\n",
    "        dim_feedforward: Width of the feed-forward sub-layer.\n",
    "        max_len: Longest sequence the positional table can address.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, d_model, num_heads, num_layers, dim_feedforward, max_len):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size, d_model)\n",
    "        self.positional_encoding = nn.Parameter(torch.randn(1, max_len, d_model))\n",
    "\n",
    "        # forward() feeds (batch, seq, d_model) tensors, so the layers must be\n",
    "        # batch-first; the (seq, batch, d_model) default would run attention\n",
    "        # across the batch dimension instead of across time.\n",
    "        decoder_layer = nn.TransformerDecoderLayer(\n",
    "            d_model, num_heads, dim_feedforward, batch_first=True\n",
    "        )\n",
    "        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)\n",
    "        self.fc_out = nn.Linear(d_model, vocab_size)\n",
    "\n",
    "    @staticmethod\n",
    "    def _future_mask(num_queries, num_keys, device):\n",
    "        \"\"\"Boolean (num_queries, num_keys) mask; True blocks key j > query i.\"\"\"\n",
    "        return torch.triu(\n",
    "            torch.ones(num_queries, num_keys, dtype=torch.bool, device=device),\n",
    "            diagonal=1,\n",
    "        )\n",
    "\n",
    "    def forward(self, input_seq, memory, causal_memory=True):\n",
    "        \"\"\"Return next-token logits for every position of ``input_seq``.\n",
    "\n",
    "        Args:\n",
    "            input_seq: Integer tensor of token ids, shape (batch, seq_len).\n",
    "            memory: Integer tensor of token ids, shape (batch, mem_len). It is\n",
    "                embedded with the same table and used for cross-attention.\n",
    "            causal_memory: When True (default), cross-attention at position i\n",
    "                only sees memory positions <= i. This is required when\n",
    "                ``memory`` is the same sequence as ``input_seq``; pass False\n",
    "                when ``memory`` is an independent source sequence.\n",
    "\n",
    "        Returns:\n",
    "            Float tensor of shape (batch, seq_len, vocab_size).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If either sequence is longer than ``max_len``.\n",
    "        \"\"\"\n",
    "        seq_len, mem_len = input_seq.size(1), memory.size(1)\n",
    "        max_len = self.positional_encoding.size(1)\n",
    "        if seq_len > max_len or mem_len > max_len:\n",
    "            raise ValueError(\n",
    "                f\"sequence length {max(seq_len, mem_len)} exceeds max_len={max_len}\"\n",
    "            )\n",
    "\n",
    "        embedded = self.embedding(input_seq) + self.positional_encoding[:, :seq_len, :]\n",
    "        memory = self.embedding(memory) + self.positional_encoding[:, :mem_len, :]\n",
    "\n",
    "        # Position i may only look at positions <= i; without these masks the\n",
    "        # model can read the very token it is supposed to predict.\n",
    "        tgt_mask = self._future_mask(seq_len, seq_len, input_seq.device)\n",
    "        memory_mask = (\n",
    "            self._future_mask(seq_len, mem_len, input_seq.device) if causal_memory else None\n",
    "        )\n",
    "        output = self.transformer_decoder(\n",
    "            embedded, memory, tgt_mask=tgt_mask, memory_mask=memory_mask\n",
    "        )\n",
    "        return self.fc_out(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "训练样本: torch.Size([603, 50]) torch.Size([603, 50])\n"
     ]
    }
   ],
   "source": [
    "# Hyperparameters\n",
    "SEED = 42  # fixed seed so weight initialisation is reproducible on a fresh kernel\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "vocab_size = tokenizer.get_vocab_size()\n",
    "d_model = 64           # embedding / hidden width\n",
    "num_heads = 4          # attention heads per layer\n",
    "num_layers = 3         # stacked decoder layers\n",
    "dim_feedforward = 128  # feed-forward sub-layer width\n",
    "max_len = 200          # rows in the learned positional table\n",
    "\n",
    "model = TransformerDecoder(\n",
    "    vocab_size=vocab_size,\n",
    "    d_model=d_model,\n",
    "    num_heads=num_heads,\n",
    "    num_layers=num_layers,\n",
    "    dim_feedforward=dim_feedforward,\n",
    "    max_len=max_len,\n",
    ")\n",
    "\n",
    "# ======================= 4️⃣  训练数据准备 =======================\n",
    "def create_training_data(text, tokenizer, seq_length=50):\n",
    "    \"\"\"Slice ``text`` into overlapping next-token-prediction training pairs.\n",
    "\n",
    "    The text is tokenized once and a window of ``seq_length`` tokens is slid\n",
    "    over it with stride 1; each target row is its input row shifted one token\n",
    "    to the right.\n",
    "\n",
    "    Args:\n",
    "        text: Raw training text.\n",
    "        tokenizer: Object whose ``encode(text).ids`` yields a list of int ids.\n",
    "        seq_length: Number of tokens per training sequence (>= 1).\n",
    "\n",
    "    Returns:\n",
    "        ``(inputs, targets)``: two int64 tensors of shape\n",
    "        ``(num_tokens - seq_length, seq_length)``. When the text has too few\n",
    "        tokens for a single window, both have shape ``(0, seq_length)``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``seq_length`` is smaller than 1.\n",
    "    \"\"\"\n",
    "    if seq_length < 1:\n",
    "        raise ValueError(\"seq_length must be at least 1\")\n",
    "\n",
    "    token_ids = torch.tensor(tokenizer.encode(text).ids, dtype=torch.long)\n",
    "    window = seq_length + 1  # seq_length inputs plus the last shifted target\n",
    "    if token_ids.numel() < window:\n",
    "        # Keep a well-formed (0, seq_length) int64 result instead of a float (0,) tensor.\n",
    "        empty = torch.empty((0, seq_length), dtype=torch.long)\n",
    "        return empty, empty.clone()\n",
    "\n",
    "    # unfold yields every stride-1 window as a view; slice it into the two shifted halves.\n",
    "    windows = token_ids.unfold(0, window, 1)\n",
    "    return windows[:, :-1].contiguous(), windows[:, 1:].contiguous()\n",
    "\n",
    "# Build the (inputs, targets) training tensors from the corpus and report their shapes\n",
    "inputs, targets = create_training_data(text, tokenizer)\n",
    "print(\"训练样本:\", inputs.shape, targets.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0, Loss: 5.30748176574707\n",
      "Epoch 1, Loss: 4.808955669403076\n",
      "Epoch 2, Loss: 3.4255547523498535\n",
      "Epoch 3, Loss: 2.2905361652374268\n",
      "Epoch 4, Loss: 1.5883567333221436\n",
      "Epoch 5, Loss: 1.3537038564682007\n",
      "Epoch 6, Loss: 1.3528802394866943\n",
      "Epoch 7, Loss: 1.20084810256958\n",
      "Epoch 8, Loss: 1.205197811126709\n",
      "Epoch 9, Loss: 1.1740862131118774\n",
      "Epoch 10, Loss: 1.1421250104904175\n",
      "Epoch 11, Loss: 1.1442739963531494\n",
      "Epoch 12, Loss: 1.1199126243591309\n",
      "Epoch 13, Loss: 1.1381272077560425\n",
      "Epoch 14, Loss: 1.1888636350631714\n",
      "Epoch 15, Loss: 1.0965999364852905\n",
      "Epoch 16, Loss: 1.2028051614761353\n",
      "Epoch 17, Loss: 1.1394095420837402\n",
      "Epoch 18, Loss: 1.1190690994262695\n",
      "Epoch 19, Loss: 1.1053941249847412\n"
     ]
    }
   ],
   "source": [
    "# Wrap the training tensors in a shuffled mini-batch loader\n",
    "dataset = TensorDataset(inputs, targets)\n",
    "dataloader = DataLoader(dataset, shuffle=True, batch_size=16)\n",
    "\n",
    "# ======================= 5️⃣  Train the model =======================\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "# Scale the learning rate by 0.95 after every 10 scheduler steps\n",
    "scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.95, step_size=10)\n",
    "\n",
    "def train(model, dataloader, epochs=20):\n",
    "    \"\"\"Train ``model`` for next-token prediction, printing the mean loss per epoch.\n",
    "\n",
    "    Uses the module-level ``optimizer`` and ``scheduler`` objects, which must\n",
    "    exist before this function is called.\n",
    "\n",
    "    Args:\n",
    "        model: Module called as ``model(batch_inputs, batch_inputs)`` that\n",
    "            returns logits whose last dimension is the vocabulary size.\n",
    "        dataloader: Iterable of ``(batch_inputs, batch_targets)`` token-id tensors.\n",
    "        epochs: Number of passes over ``dataloader``.\n",
    "    \"\"\"\n",
    "    criterion = nn.CrossEntropyLoss()  # built once instead of once per batch\n",
    "    # Re-enable dropout in case the model was last switched to eval mode.\n",
    "    model.train()\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        total_loss, num_batches = 0.0, 0\n",
    "        for batch_inputs, batch_targets in dataloader:\n",
    "            optimizer.zero_grad()\n",
    "            output = model(batch_inputs, batch_inputs)\n",
    "            # Flatten to (tokens, classes) vs (tokens,); the class count comes\n",
    "            # from the logits themselves rather than a global variable.\n",
    "            loss = criterion(output.reshape(-1, output.size(-1)), batch_targets.reshape(-1))\n",
    "            loss.backward()\n",
    "            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # guard against exploding gradients\n",
    "            optimizer.step()\n",
    "            total_loss += loss.item()\n",
    "            num_batches += 1\n",
    "        scheduler.step()\n",
    "        # Report the epoch mean: the last mini-batch alone is a noisy estimate\n",
    "        # and would be undefined for an empty dataloader.\n",
    "        mean_loss = total_loss / num_batches if num_batches else float(\"nan\")\n",
    "        print(f\"Epoch {epoch}, Loss: {mean_loss}\")\n",
    "\n",
    "# Start training\n",
    "train(model, dataloader, epochs=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "生成文本: he ll o the president t , the president that in the president that the president that the president that the president that in the president that the president that the president said . “ that the president that the president that in on the u k rain e u k a t , so that the u this is . u dz ha , with with the general , with that in u , that the president and told t a m had with u , that in the region in the r that r ussi an s , and m and\n"
     ]
    }
   ],
   "source": [
    "# ======================= 6️⃣  文本生成（Greedy 解码） =======================\n",
    "def generate_text(model, tokenizer, start_text, max_length=100):\n",
    "    \"\"\"Greedily extend ``start_text`` by ``max_length`` tokens.\n",
    "\n",
    "    Args:\n",
    "        model: Module called as ``model(ids, ids)`` with a (1, seq) tensor of\n",
    "            token ids; returns logits whose last dimension indexes the vocabulary.\n",
    "        tokenizer: Object providing ``encode(text).ids`` and ``decode(ids)``.\n",
    "        start_text: Prompt; must encode to at least one token.\n",
    "        max_length: Number of new tokens to append.\n",
    "\n",
    "    Returns:\n",
    "        The decoded prompt tokens followed by the generated tokens.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``start_text`` encodes to no tokens.\n",
    "    \"\"\"\n",
    "    generated = list(tokenizer.encode(start_text).ids)\n",
    "    if not generated:\n",
    "        raise ValueError(\"start_text must encode to at least one token\")\n",
    "\n",
    "    # Build inputs on the model's device; parameter-free models fall back to CPU.\n",
    "    first_param = next(model.parameters(), None)\n",
    "    device = first_param.device if first_param is not None else torch.device(\"cpu\")\n",
    "\n",
    "    # Never feed more tokens than the positional table can address (when the\n",
    "    # model exposes one); otherwise long prompts or runs crash inside forward().\n",
    "    positions = getattr(model, \"positional_encoding\", None)\n",
    "    context_size = positions.size(1) if positions is not None else len(generated) + max_length\n",
    "\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            for _ in range(max_length):\n",
    "                window = torch.tensor(\n",
    "                    [generated[-context_size:]], dtype=torch.long, device=device\n",
    "                )\n",
    "                logits = model(window, window)\n",
    "                # Greedy decoding: take the most likely token at the last position.\n",
    "                generated.append(logits[0, -1].argmax(dim=-1).item())\n",
    "    finally:\n",
    "        # Leave the model in whatever mode the caller had it in.\n",
    "        model.train(was_training)\n",
    "\n",
    "    return tokenizer.decode(generated)\n",
    "\n",
    "# Try out generation with a short prompt\n",
    "print(\"生成文本:\", generate_text(model, tokenizer, \"hello\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
